﻿// ------------------------------------------------------------------------
//  Copyright 2025 The Dapr Authors
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//      http://www.apache.org/licenses/LICENSE-2.0
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
//  ------------------------------------------------------------------------

using Dapr.AI.Conversation.ConversationRoles;
using Dapr.AI.Conversation.Extensions;
using Dapr.AI.Conversation.Tools;
using Dapr.Common.Extensions;
using Google.Protobuf.WellKnownTypes;
using Autogenerated = Dapr.Client.Autogen.Grpc.v1;

namespace Dapr.AI.Conversation;

/// <summary>
/// Protobuf mapping utilities that translate between the Conversation domain types and their
/// generated gRPC (Alpha2) equivalents.
/// </summary>
public static class ConversationProtoUtilities
{
    /// <summary>
    /// Converts an <see cref="IEnumerable{MessageContent}"/> into its protobuf equivalent.
    /// </summary>
    /// <param name="contents">The contents to map.</param>
    /// <returns>
    /// A lazily-evaluated sequence containing one <see cref="Autogenerated.ConversationMessageContent"/>
    /// per input item, carrying that item's text.
    /// </returns>
    public static IEnumerable<Autogenerated.ConversationMessageContent> ToProtoContents(
        this IEnumerable<MessageContent> contents) =>
        contents.Select(c => new Autogenerated.ConversationMessageContent { Text = c.Text });

    /// <summary>
    /// Converts an <see cref="IReadOnlyList{ConversationInput}"/> into its protobuf equivalent.
    /// </summary>
    /// <param name="inputs">The conversation inputs to map.</param>
    /// <returns>
    /// A lazily-evaluated sequence with one <see cref="Autogenerated.ConversationInputAlpha2"/> per input,
    /// each containing that input's mapped messages.
    /// </returns>
    private static IEnumerable<Autogenerated.ConversationInputAlpha2> ToProto(this IReadOnlyList<ConversationInput> inputs) =>
        inputs.Select(input =>
        {
            var output = new Autogenerated.ConversationInputAlpha2();
            output.Messages.AddRange(input.Messages.Select(msg => msg.ToProto()));

            // Only populate the field when the caller explicitly supplied a value.
            if (input.ScrubPII is not null)
            {
                output.ScrubPii = input.ScrubPII.Value;
            }

            return output;
        });

    /// <summary>
    /// Creates an <see cref="Autogenerated.ConversationRequestAlpha2"/> from the provided inputs and options.
    /// </summary>
    /// <param name="inputs">The conversation inputs.</param>
    /// <param name="options">
    /// The conversation options: target component, context id, tool choice, parameters, metadata,
    /// PII scrubbing, temperature and tools.
    /// </param>
    /// <returns>The populated protobuf request.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown when <paramref name="options"/> contains a tool that is not a <see cref="ToolFunction"/>.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// Thrown when an assistant message contains a tool call that is not a <see cref="CalledToolFunction"/>.
    /// </exception>
    /// <exception cref="NotImplementedException">
    /// Thrown when an input contains a message of an unrecognized type.
    /// </exception>
    public static Autogenerated.ConversationRequestAlpha2 CreateConversationInputRequest(IReadOnlyList<ConversationInput> inputs,
        ConversationOptions options)
    {
        var request = new Autogenerated.ConversationRequestAlpha2 { Name = options.ConversationComponentId };

        if (options.ContextId is not null)
        {
            request.ContextId = options.ContextId;
        }

        if (options.ToolChoice is not null)
        {
            // NOTE(review): relies on ToolChoice.ToString() producing the wire value the runtime expects — confirm.
            request.ToolChoice = options.ToolChoice.ToString();
        }

        // AddRange enumerates the lazy mapping here, so message-mapping exceptions surface from this call.
        request.Inputs.AddRange(inputs.ToProto());

        foreach (var p in options.Parameters)
        {
            request.Parameters.Add(p.Key, p.Value);
        }

        foreach (var m in options.Metadata)
        {
            request.Metadata.Add(m.Key, m.Value);
        }

        if (options.ScrubPII is not null)
        {
            request.ScrubPii = options.ScrubPII.Value;
        }

        if (options.Temperature is not null)
        {
            request.Temperature = options.Temperature.Value;
        }

        if (options.Tools.Count > 0)
        {
            var tools = options.Tools.Select(tool =>
            {
                switch (tool)
                {
                    case ToolFunction toolF:
                    {
                        var toolFunction = new Autogenerated.ConversationToolsFunction { Name = toolF.Name };

                        if (toolF.Description is not null)
                        {
                            toolFunction.Description = toolF.Description;
                        }

                        // Each parameter entry becomes a field of a protobuf Struct.
                        var parametersStruct = new Struct();
                        foreach (var (k, v) in toolF.Parameters)
                        {
                            parametersStruct.Fields[k] = ProtobufHelpers.ToValue(v);
                        }

                        toolFunction.Parameters = parametersStruct;

                        return new Autogenerated.ConversationTools { Function = toolFunction };
                    }
                    default:
                        throw new NotSupportedException($"Unsupported tool type: {tool.GetType().FullName}");
                }
            }).ToRepeatedField();
            request.Tools.AddRange(tools);
        }

        return request;
    }

    /// <summary>
    /// Maps the <see cref="Autogenerated.ConversationResponseAlpha2"/> to a <see cref="ConversationResponse"/>.
    /// </summary>
    /// <param name="result">The result from the conversation API to parse.</param>
    /// <returns>
    /// A <see cref="ConversationResponse"/> with one result per protobuf output and one choice per protobuf choice.
    /// A choice's finish reason is <c>null</c> when the protobuf value cannot be parsed into <see cref="FinishReason"/>.
    /// </returns>
    public static ConversationResponse ToDomain(this Autogenerated.ConversationResponseAlpha2 result)
    {
        string? contextId = result.ContextId;
        var conversationResults = result.Outputs.Select(convoResult =>
        {
            var choices = convoResult.Choices.Select(c =>
            {
                var didParseReason = c.FinishReason.TryParseEnumMember<FinishReason>(out var parsedFinishReason);
                var resultMessage = new ResultMessage(c.Message.Content)
                {
                    // Every returned tool call is surfaced as a function call.
                    ToolCalls = c.Message.ToolCalls
                        .Select(ToolCallBase (tc) =>
                            new CalledToolFunction(tc.Function.Name, tc.Function.Arguments)
                            {
                                Id = tc.Id
                            })
                        .ToList()
                };

                return new ConversationResultChoice(didParseReason ? parsedFinishReason : null,
                    c.Index, resultMessage);
            }).ToList();
            return new ConversationResponseResult(choices);
        }).ToList();

        return new ConversationResponse(conversationResults, contextId);
    }

    /// <summary>
    /// Converts an <see cref="IConversationMessage"/> into its protobuf equivalent.
    /// </summary>
    /// <param name="message">The message to convert.</param>
    /// <returns>
    /// A <see cref="Autogenerated.ConversationMessage"/> with the role-specific payload
    /// (developer, user, assistant, system or tool) populated.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// Thrown when an <see cref="AssistantMessage"/> contains a tool call that is not a <see cref="CalledToolFunction"/>.
    /// </exception>
    /// <exception cref="NotImplementedException">Thrown when the message type is not recognized.</exception>
    private static Autogenerated.ConversationMessage ToProto(this IConversationMessage message)
    {
        // Reuse the shared content mapping, materialized once for whichever role branch applies.
        var messageContents = message.Content.ToProtoContents().ToList();

        switch (message)
        {
            case DeveloperMessage devMessage:
            {
                var output = new Autogenerated.ConversationMessageOfDeveloper();
                if (devMessage.Name is not null)
                {
                    output.Name = devMessage.Name;
                }

                output.Content.AddRange(messageContents);

                return new Autogenerated.ConversationMessage { OfDeveloper = output };
            }
            case UserMessage userMessage:
            {
                var output = new Autogenerated.ConversationMessageOfUser();
                if (userMessage.Name is not null)
                {
                    output.Name = userMessage.Name;
                }

                output.Content.AddRange(messageContents);

                return new Autogenerated.ConversationMessage { OfUser = output };
            }
            case AssistantMessage assistantMessage:
            {
                var output = new Autogenerated.ConversationMessageOfAssistant();
                if (assistantMessage.Name is not null)
                {
                    output.Name = assistantMessage.Name;
                }

                output.Content.AddRange(messageContents);

                var toolContent = assistantMessage.ToolCalls.Select(toolCall =>
                {
                    if (toolCall is CalledToolFunction funcToolCall)
                    {
                        var protoToolCall = new Autogenerated.ConversationToolCalls
                        {
                            Function = new Autogenerated.ConversationToolCallsOfFunction
                            {
                                Name = funcToolCall.Name,
                                Arguments = funcToolCall.JsonArguments
                            }
                        };

                        // Generated protobuf string setters reject null, so only copy an id that is present
                        // (mirrors the ToolId handling for ToolMessage below).
                        if (funcToolCall.Id is not null)
                        {
                            protoToolCall.Id = funcToolCall.Id;
                        }

                        return protoToolCall;
                    }

                    throw new ArgumentException($"Unrecognized tool call type for identifier '{toolCall.Id}'");
                });
                output.ToolCalls.AddRange(toolContent);

                return new Autogenerated.ConversationMessage { OfAssistant = output };
            }
            case SystemMessage systemMessage:
            {
                var output = new Autogenerated.ConversationMessageOfSystem();
                if (systemMessage.Name is not null)
                {
                    output.Name = systemMessage.Name;
                }

                output.Content.AddRange(messageContents);

                return new Autogenerated.ConversationMessage { OfSystem = output };
            }
            case ToolMessage toolMessage:
            {
                var output = new Autogenerated.ConversationMessageOfTool
                {
                    Name = toolMessage.Name
                };

                if (toolMessage.Id is not null)
                {
                    output.ToolId = toolMessage.Id;
                }

                output.Content.AddRange(messageContents);

                return new Autogenerated.ConversationMessage { OfTool = output };
            }
            default:
                throw new NotImplementedException("Message type not recognized.");
        }
    }
}
